---
title: "PLAN for Chapter 1. Methodology for fitting ODE epidemic models"
subtitle: "This RMarkdown is a template as initially outlined in the progress report. Each task will be outlined and then followed by a code block to be completed."
output:
  html_notebook:
    number_sections: true
---
This notebook should be similar to Chapter_01.Rmd, but with the experiment data structure improved so that runs can use more cores than there are scenarios.
When calibrating their vivax population transmission models, White and Champagne lack longitudinal time-series data and are forced to make equilibrium assumptions at a single point in time (an approach taken from Griffin's 2014 P. falciparum model). Under this assumption, the derivatives are set to zero and the equations are solved for the unknowns. In Champagne's case, the unknown is the transmission rate \(\lambda\), so that the force of infection is \(\lambda\) multiplied by the sum of the infectious compartments \(\sum{I}\).
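To make the equilibrium approach concrete, here is a minimal sketch on a toy SIS model, not the Champagne model itself. With force of infection \(\lambda I\), recovery rate \(r\) and \(S = 1 - I\), setting \(dI/dt = \lambda I S - r I = 0\) gives \(\lambda = r/(1 - I^*)\) for an observed endemic prevalence \(I^*\). The function names, the recovery rate and the prevalence value are illustrative assumptions.

```r
# Toy SIS model (an illustrative assumption, not the Champagne 2022 equations):
#   dI/dt = lambda * I * (1 - I) - r * I
sis_derivative <- function(I, lambda, r) {
  lambda * I * (1 - I) - r * I
}

# Equilibrium calibration: given an observed endemic prevalence I_star,
# find the lambda for which dI/dt = 0 at I = I_star.
solve_lambda_equilibrium <- function(I_star, r) {
  uniroot(function(lambda) sis_derivative(I_star, lambda, r),
          interval = c(r, 100 * r), tol = 1e-12)$root
}

r <- 1 / 60        # hypothetical recovery rate (per day)
I_star <- 0.2      # hypothetical observed prevalence
lambda_num <- solve_lambda_equilibrium(I_star, r)
lambda_exact <- r / (1 - I_star)   # closed form from setting dI/dt = 0
```

A single cross-sectional observation pins down \(\lambda\) only because the derivative is assumed to be zero; once incidence oscillates seasonally, there is no single \(I^*\) to plug in.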
This assumption does not hold in areas with moderate to strong seasonality or with long-term trends. This document demonstrates the standard implementation of dynamic ODE parameter-fitting methods applied to the White/Champagne-style model, and extends the form of the model (and its corresponding parameter estimation routine) to provide increasingly flexible relaxations of the assumptions behind the original uses of the White/Champagne model. The end result will be a model and fitting routine flexible enough to reflect vivax epidemics in non-stationary regions, unlike the original implementation.
[…]
Stan and posterior sampling will be used for all analyses.
library(R.utils)
## Loading required package: R.oo
## Loading required package: R.methodsS3
## R.methodsS3 v1.8.2 (2022-06-13 22:00:14 UTC) successfully loaded. See ?R.methodsS3 for help.
## R.oo v1.25.0 (2022-06-12 02:20:02 UTC) successfully loaded. See ?R.oo for help.
##
## Attaching package: 'R.oo'
## The following object is masked from 'package:R.methodsS3':
##
## throw
## The following objects are masked from 'package:methods':
##
## getClasses, getMethods
## The following objects are masked from 'package:base':
##
## attach, detach, load, save
## R.utils v2.12.2 (2022-11-11 22:00:03 UTC) successfully loaded. See ?R.utils for help.
##
## Attaching package: 'R.utils'
## The following object is masked from 'package:utils':
##
## timestamp
## The following objects are masked from 'package:base':
##
## cat, commandArgs, getOption, isOpen, nullfile, parse, warnings
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.2 ✔ readr 2.1.4
## ✔ forcats 1.0.0 ✔ stringr 1.5.0
## ✔ ggplot2 3.4.4 ✔ tibble 3.2.1
## ✔ lubridate 1.9.2 ✔ tidyr 1.3.0
## ✔ purrr 1.0.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ tidyr::extract() masks R.utils::extract()
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(rstan)
## Loading required package: StanHeaders
##
## rstan version 2.32.3 (Stan version 2.26.1)
##
## For execution on a local, multicore CPU with excess RAM we recommend calling
## options(mc.cores = parallel::detectCores()).
## To avoid recompilation of unchanged Stan programs, we recommend calling
## rstan_options(auto_write = TRUE)
## For within-chain threading using `reduce_sum()` or `map_rect()` Stan functions,
## change `threads_per_chain` option:
## rstan_options(threads_per_chain = 1)
##
##
## Attaching package: 'rstan'
##
## The following object is masked from 'package:tidyr':
##
## extract
##
## The following object is masked from 'package:R.utils':
##
## extract
library(rstansim) # devtools::install_github("ewan-keith/rstansim")
## Loading required package: Rcpp
##
## Attaching package: 'rstansim'
##
## The following object is masked from 'package:dplyr':
##
## rename
library(parallel)
library(patchwork)
library(pbmcapply)
library(pbapply)
library(memoise)
source("../R/constants.R")
source("../R/methods.R")
n_cores = parallelly::availableCores()
options(mc.cores = n_cores)
message("Running on ", n_cores, " cores")
## Running on 10 cores
rstan_options(auto_write = TRUE)
# Store generated data here
cd = cachem::cache_disk("sim_data")
n_years = 5
n_iter = 2500 # should be at least 500
n_chains = 2
n_repetitions = 10 # how many times to duplicate each scenario
cores_per_sampler = 1 # set to n_chains if not running lots of scenarios
limit_runs = Inf # set to a finite number for testing, or Inf to run all
timelimit_per_run = 60*60 * 5
Define all models
model_champagne2022 = "stan/champagne2022.stan"
stan_model_champagne2022 = stan_model(model_champagne2022)
model_champagne2022_poisson = "stan/champagne2022_poisson.stan"
stan_model_champagne2022_poisson = stan_model(model_champagne2022_poisson)
model_champagne2022_seasonal = "stan/champagne2022_seasonal.stan"
stan_model_champagne2022_seasonal = stan_model(model_champagne2022_seasonal)
model_champagne2022_seasonal_poisson = "stan/champagne2022_seasonal_poisson.stan"
stan_model_champagne2022_seasonal_poisson = stan_model(model_champagne2022_seasonal_poisson)
model_champagne2022_seasonal_ext = "stan/champagne2022_seasonal_ext.stan"
stan_model_champagne2022_seasonal_ext = stan_model(model_champagne2022_seasonal_ext)
We begin by extending Champagne's 2022 model for tropical vivax to include seasonality.
First we implement and verify the parameter recovery ability of Stan with the Champagne model as published in 2022.
# perform simulation study
dt = years/annual_subdivisions
t0 = -50*years
t = seq_len(n_years*annual_subdivisions) * dt
n_times = length(t)
N = 1000 # population size
#initial conditions
I_init = 0.01
y0 = c(Il=0, I0=I_init, Sl=0, S0=1-I_init, CumulativeInfections=0)
# constants for Stan
data_consts = list(n_times = n_times+1,
y0 = y0,
t0 = t0,
ts = seq_len(n_times+1) * dt,
N = N,
cases = rep(99999, n_times+1),
r = 1./60, # r
gammal = 1./223, # gammal
f = 1./72, # f
alpha = 0.21, # alpha
beta = 0.66, # beta
rho = 0.21, # rho
delta = 0,
eps = 0,
kappa = 1,
phase = 0
)
# Generate synthetic observations
real_params = list(lambda=0.02, phi_inv=0.1)
synth_data = simulate_data(
file = model_champagne2022,
data_name = "dummy_data",
input_data = data_consts,
param_values = real_params,
vars = c("ts", "sim_cases")
)
synth_data_rds = readRDS(synth_data$datasets[1])
indx <- sapply(synth_data_rds, length)
synth_df = lapply(synth_data_rds, function(x) {length(x) = max(indx); x}) %>%
as.data.frame() %>%
drop_na()
ggplot(synth_df, aes(x=ts, y=cases)) +
geom_line()
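Check that maximum likelihood estimation gives a very good fit (rstan::optimizing() returns the mode of the log density, which equals the MLE only if the priors are flat).
# Fit data using L-BFGS (the default algorithm of rstan::optimizing)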
data = data_consts
data$n_times = n_times
data$ts = synth_df$ts
data$cases = synth_df$cases
optim = optimizing(stan_model_champagne2022,
data = data)
theta_init = as.list(optim$par[c("lambda", "phi_inv")]) # optimisation results
plot_data = tibble(name = names(optim$par), value = unname(optim$par)) %>%
mutate(index_i = name %>% str_extract("(?<=\\[)[0-9]+") %>% as.numeric(),
time = t[index_i],
index_j = name %>% str_extract("(?<=,)[0-9]+") %>% as.numeric(),
compartment = names(y0)[index_j],
variable = name %>% str_remove("\\[.*")) %>%
mutate(variable = coalesce(compartment, variable))
plot_data %>%
drop_na(time) %>%
ggplot(aes(x=time, y=value, color=variable, group=variable)) +
geom_line() +
facet_wrap(vars(variable), scales="free_y") +
coord_cartesian(ylim = c(0, NA))
Do the full posterior fit. It should predict the data very well.
fit_nonseasonal = sampling(stan_model_champagne2022,
data = data,
iter = n_iter,
chains = n_chains,
init = rep(list(theta_init), n_chains), # Start from MLE solution
seed = 0)
# pairs(fit_nonseasonal, pars=c("lambda", "phi_inv"))
smr_pred_nonseasonal <- with(data, cbind(as.data.frame(
summary(
fit_nonseasonal,
pars = "sim_cases",
probs = c(0.05, 0.5, 0.95)
)$summary),
t=t[1:(n_times-1)], cases=synth_df$cases[1:(n_times-1)])) %>%
setNames(colnames({.}))
ggplot(smr_pred_nonseasonal, mapping = aes(x = t)) +
geom_ribbon(aes(ymin = `5%`, ymax = `95%`), fill = c_posterior, alpha = 0.35) +
geom_line(mapping = aes(y = `50%`), color = c_posterior) +
geom_point(mapping = aes(y = cases)) +
labs(x = "Day", y = "Cases") +
coord_cartesian(ylim = c(0, NA))
Check the parameter estimates, which should be almost exact.
params_extract = rstan::extract(fit_nonseasonal, c("lambda", "phi_inv")) %>%
lapply(as.numeric) %>%
as_tibble()
real_params_df = as.data.frame(real_params) %>%
pivot_longer(everything())
params_extract %>%
pivot_longer(everything()) %>%
ggplot(aes(x = value)) +
geom_density(color="steelblue", fill="steelblue", alpha=0.75) +
geom_vline(data=real_params_df, aes(xintercept=value), linetype="dashed") +
scale_x_continuous(limits = c(0, NA)) +
facet_wrap(vars(name), scales = "free") +
labs(title = "Re-estimated posterior densities for non-seasonal model",
subtitle = "Dashed line: simulated value")
Check individual epidemic trajectories. Verify that both the sampled cases and the incidence traces are in steady state.
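A crude numerical complement to the visual check is to compare the mean of the first and second halves of each trace. The helper below is hypothetical (not part of methods.R), and the 5% tolerance is an arbitrary assumption; noisy sampled cases may need a looser tolerance, and for seasonal traces each half should span a whole number of cycles.

```r
# Hypothetical helper: flag a trace as "steady" if the mean of its second half
# is within a relative tolerance of the mean of its first half.
is_steady <- function(x, rel_tol = 0.05) {
  n <- length(x) %/% 2
  first <- mean(x[seq_len(n)])
  second <- mean(x[(length(x) - n + 1):length(x)])
  abs(second - first) <= rel_tol * max(abs(first), .Machine$double.eps)
}

# Example usage on the extracted draws (rows = draws, columns = time points):
# mean(apply(ts_extract, 1, is_steady))   # fraction of draws judged steady
```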
ts_extract = rstan::extract(fit_nonseasonal, "incidence")[[1]]
n_traces = 1000
ix = sample(seq_len(dim(ts_extract)[1]), n_traces, replace=TRUE)
ts_sample = as_tibble(t(ts_extract[ix,])) %>%
mutate(j = row_number()) %>%
pivot_longer(-j, names_to = "trace")
p1 = ggplot(ts_sample, aes(x=j, y=value, group=trace)) +
geom_line(alpha = 0.1) +
coord_cartesian(ylim = c(0, NA)) +
labs(title = paste(n_traces, "traces of simulated incidence"))
ts_extract = rstan::extract(fit_nonseasonal, "sim_cases")[[1]]
ix = sample(seq_len(dim(ts_extract)[1]), n_traces, replace=TRUE)
ts_sample = as_tibble(t(ts_extract[ix,])) %>%
mutate(j = row_number()) %>%
pivot_longer(-j, names_to = "trace")
p2 = ggplot(ts_sample, aes(x=j, y=value, group=trace)) +
geom_line(alpha = 0.1) +
labs(title = paste(n_traces, "traces of simulated cases"))
p1 / p2
Then we generate data using the seasonal model.
# Generate synthetic observations
real_params = list(lambda=0.01, phi_inv=0.1)
synth_data = simulate_data(
file = model_champagne2022_seasonal,
data_name = "dummy_data",
input_data = data_consts,
param_values = real_params,
vars = c("ts", "sim_cases", "susceptible", "R0", "Rc")
)
synth_data_rds = readRDS(synth_data$datasets[1])
indx <- sapply(synth_data_rds, length)
synth_df = lapply(synth_data_rds, function(x) {length(x) = max(indx); x}) %>%
as.data.frame() %>%
drop_na()
synth_df %>%
pivot_longer(-ts) %>%
ggplot(aes(x=(ts-data_consts$t0)/years, y=value)) +
geom_line() +
facet_wrap(vars(name), scales="free_y")
Insert the new seasonal data into a data list and run the seasonal model (MLE and posterior sampling) on it.
# Fit data
# Edit data with generated values
data = data_consts
data$n_times = n_times
data$ts = synth_df$ts
data$cases = synth_df$cases
optim = optimizing(stan_model_champagne2022_seasonal,
init = lapply(real_params, function(x) {100*x}),
data = data)
# Create initial values for solving efficiency
theta_init = as.list(optim$par[c("lambda", "phi_inv")]) # optimisation results
fit_seasonal = sampling(stan_model_champagne2022_seasonal,
data = data,
iter = n_iter,
chains = n_chains,
init = rep(list(theta_init), n_chains), # Start from MLE solution
seed = 0)
pairs(fit_seasonal, pars=c("lambda", "phi_inv"))
smr_pred_seasonal <- with(synth_df, cbind(as.data.frame(
summary(
fit_seasonal,
pars = "sim_cases",
probs = c(0.05, 0.5, 0.95)
)$summary),
t=t[1:(n_times-1)], cases=cases[1:(n_times-1)])) %>%
setNames(colnames({.}))
ggplot(smr_pred_seasonal, mapping = aes(x = t)) +
geom_ribbon(aes(ymin = `5%`, ymax = `95%`), fill = c_posterior, alpha = 0.35) +
geom_line(mapping = aes(y = `50%`), color = c_posterior) +
geom_point(mapping = aes(y = cases)) +
labs(x = "Day", y = "Cases") +
coord_cartesian(ylim = c(0, NA))
Check whether the Bayesian fit with the correct model recovered the true parameters under seasonality, which it should.
real_params_df = as.data.frame(real_params) %>%
pivot_longer(everything())
rstan::extract(fit_seasonal, c("lambda", "phi_inv")) %>%
lapply(as.numeric) %>%
as_tibble() %>%
pivot_longer(everything()) %>%
ggplot(aes(x = value)) +
geom_density(color="steelblue", fill="steelblue", alpha=0.75) +
geom_vline(data=real_params_df, aes(xintercept=value), linetype="dashed") +
# scale_x_continuous(limits = c(0, NA)) +
facet_wrap(vars(name), scales = "free") +
labs(title = "Re-estimated posterior densities with seasonal model",
subtitle = "Dashed line: simulated value")
What if we removed seasonality from the data and tried to fit the non-seasonal model?
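One way to remove seasonality is to sum cases over whole years, so each observation covers a full seasonal cycle. The helper below is a hypothetical stand-in: the actual aggregate_data() lives in methods.R and is not shown here. The steps_per_year argument presumably corresponds to annual_subdivisions.

```r
# Hypothetical helper (not the aggregate_data() from methods.R):
# sum a case series into annual totals.
aggregate_annual <- function(cases, steps_per_year) {
  n_full_years <- length(cases) %/% steps_per_year
  # Drop any trailing partial year so every total covers the same span
  kept <- cases[seq_len(n_full_years * steps_per_year)]
  year <- rep(seq_len(n_full_years), each = steps_per_year)
  as.numeric(tapply(kept, year, sum))
}

# A perfectly seasonal series gives identical annual totals
seasonal_cases <- rep(c(1, 2, 3, 4, 5, 6, 6, 5, 4, 3, 2, 1), 2)
annual_totals <- aggregate_annual(seasonal_cases, 12)
```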
data_agg = aggregate_data(data)
optim = optimizing(stan_model_champagne2022,
init = lapply(real_params, function(x) {100*x}),
data = data_agg)
theta_init = as.list(optim$par[c("lambda", "phi_inv")]) # optimisation results
fit_agg = sampling(stan_model_champagne2022,
data = data_agg,
iter = n_iter,
chains = n_chains,
init = rep(list(theta_init), n_chains), # Start from MLE solution
seed = 0)
smr_pred_agg <- with(data_agg, cbind(as.data.frame(
summary(
fit_agg,
pars = "sim_cases",
probs = c(0.05, 0.5, 0.95)
)$summary),
t=ts[1:(n_times-1)], cases=cases[1:(n_times-1)])) %>%
setNames(colnames({.}))
ggplot(smr_pred_agg, mapping = aes(x = t)) +
geom_ribbon(aes(ymin = `5%`, ymax = `95%`), fill = c_posterior, alpha = 0.35) +
geom_line(mapping = aes(y = `50%`), color = c_posterior) +
geom_point(mapping = aes(y = cases)) +
labs(x = "Day", y = "Cases") +
coord_cartesian(ylim = c(0, NA))
rstan::extract(fit_agg, c("lambda", "phi_inv")) %>%
lapply(as.numeric) %>%
as_tibble() %>%
pivot_longer(everything()) %>%
ggplot(aes(x = value)) +
geom_density(color="steelblue", fill="steelblue", alpha=0.75) +
geom_vline(data=real_params_df, aes(xintercept=value), linetype="dashed") +
scale_x_continuous(limits = c(0, NA)) +
facet_wrap(vars(name), scales = "free") +
labs(title = "Re-estimated posterior densities with non-seasonal model",
subtitle = "Dashed line: simulated value")
Under ideal conditions, model fitting using the time series should recover the original parameters with reasonable accuracy, but the Champagne 2022 method should incur some error attributable to seasonality.
We define reasonable scenarios that reflect the assumptions of the Champagne model except for seasonality, for example a range of transmission intensities and treatment capabilities. Data will be taken from oscillating or steady-state periods after the long-term trend has stabilised, as we are only relaxing the non-seasonal assumption of Champagne 2022.
Scenarios are grouped into sequences: for example, low-transmission tropical relapse across a range of seasonality magnitudes (peak-trough ratio from 0% to 100%). Such a sequence shows how the magnitude of seasonality affects the accuracy of parameter recovery.
# ascertainment_rates = c(0.25, 0.5, 0.75, 1)
# radical_cure_rates = seq(0, 1, by=0.2)
radical_cure_rates = 0.66
ascertainment_rates = data_consts$alpha
seasonality_ratio = seq(0, 1, length.out=3)
# radical_cure_rates = data_consts$beta
transmission_rates = seq(0.01, 0.02, by=0.005)
importation_rate = 0 # because constant importation makes less sense in seasonal transmission
population_size = N
Expand scenarios into a grid and generate synthetic data.
data_scenarios = expand_grid(
ascertainment_rates,
radical_cure_rates,
transmission_rates,
importation_rate,
seasonality_ratio,
population_size
) %>%
mutate(scenario_ID = LETTERS[row_number()], .before=0)
.simulate_cases = function(alpha=0.5, beta=0.5, lambda=1, delta=0, eps=0, N=100, index=0) {
data = data_consts
data$alpha = alpha
data$beta = beta
data$delta = delta
data$eps = eps # 0=full seasonality, 1=no seasonality
data$N = N
real_params = list(lambda=lambda, phi_inv=0.1)
if (index == 0) {
index = sample.int(999999999, 1)
}
synth_data = suppressMessages(simulate_data(
file = model_champagne2022_seasonal,
path = "sim_data",
data_name = paste0("data_", index),
input_data = data,
param_values = real_params,
vars = c("ts", "sim_cases", "susceptible")
))
# print(synth_data$datasets[1])
synth_data_rds = readRDS(synth_data$datasets[1])
file.remove(synth_data$datasets[1])
indx <- sapply(synth_data_rds, length)
synth_df = lapply(synth_data_rds, function(x) {length(x) = max(indx); x}) %>%
as.data.frame() %>%
drop_na()
}
# Add cases onto a dataframe of scenarios based on its parameter columns
simulate_cases = function(.scenarios) {
cases_scenarios = mclapply(seq_len(nrow(.scenarios)), function(i) {
dat = .scenarios[i,]
x = .simulate_cases(dat$ascertainment_rates, dat$radical_cure_rates, dat$transmission_rates, dat$importation_rate, dat$seasonality_ratio, dat$population_size, index=i)
})
.scenarios$cases = lapply(cases_scenarios, function(x) {x$cases})
.scenarios$ts = lapply(cases_scenarios, function(x) {x$ts})
.scenarios
}
simulate_cases_memo = memoise(simulate_cases, cache=cd)
data_scenarios_sim = data_scenarios %>%
slice(rep(1:n(), each = n_repetitions)) %>%
group_by(scenario_ID) %>%
mutate(rep = row_number()) %>%
ungroup() %>%
simulate_cases_memo()
# Display scenarios
data_scenarios_sim %>%
distinct(scenario_ID, .keep_all=T) %>%
unnest(cols = c("ts", "cases")) %>%
ggplot(aes(x = ts, y = cases, color=transmission_rates, group=interaction(seasonality_ratio, transmission_rates))) +
geom_line() +
scale_color_gradient(trans = "log", breaks=10^seq(-5, 5)) +
facet_grid(vars(seasonality_ratio), vars(transmission_rates))
data_scenarios
## # A tibble: 9 × 7
## scenario_ID ascertainment_rates radical_cure_rates transmission_rates
## <chr> <dbl> <dbl> <dbl>
## 1 A 0.21 0.66 0.01
## 2 B 0.21 0.66 0.01
## 3 C 0.21 0.66 0.01
## 4 D 0.21 0.66 0.015
## 5 E 0.21 0.66 0.015
## 6 F 0.21 0.66 0.015
## 7 G 0.21 0.66 0.02
## 8 H 0.21 0.66 0.02
## 9 I 0.21 0.66 0.02
## # ℹ 3 more variables: importation_rate <dbl>, seasonality_ratio <dbl>,
## # population_size <dbl>
For each scenario or sequence, recover the parameters using Champagne's solution on annual data and the typical ODE method. Show the resulting errors relative to the true parameter value.
Below: For testing
Execute methods on each scenario
data_scenarios_long = data_scenarios_sim %>%
tidyr::crossing(methods) %>%
head(limit_runs)
#' Run one scenario/method combination, caching the result on disk
#' @param i row index into data_scenarios_long
run_scenario_method = function(i) {
out_dir = "../run_scenario_method"
out_path = file.path(out_dir, paste0("row_", i, ".rds"))
dir.create(out_dir, showWarnings = FALSE)
# Reuse a cached result unless the previous attempt timed out (NULL estimate)
if (file.exists(out_path)) {
result = read_rds(out_path)
if (!is.null(result$estimate)) {
return(result)
}
}
start = Sys.time()
row = data_scenarios_long[i,]
.method = row$method
est = withTimeout({
if (.method == "lambda_nonseasonal_poisson") {
poisson_nonseasonal_sol(row$cases[[1]], row$population_size, row$ascertainment_rates, row$radical_cure_rates, 1, row$transmission_rates)
} else if (.method == "lambda_nonseasonal") {
nonseasonal_sol(row$cases[[1]], row$population_size, row$ascertainment_rates, row$radical_cure_rates, 1, row$transmission_rates)
} else if (.method == "lambda_seasonal_poisson") {
poisson_seasonal_sol(row$cases[[1]], row$population_size, row$ascertainment_rates, row$radical_cure_rates, 1, row$seasonality_ratio, row$transmission_rates)
} else if (.method == "lambda_seasonal") {
seasonal_sol(row$cases[[1]], row$population_size, row$ascertainment_rates, row$radical_cure_rates, 1, row$seasonality_ratio, row$transmission_rates)
} else if (.method == "lambda_seasonal_ext") {
extended_seasonal_sol(row$cases[[1]], row$population_size, row$ascertainment_rates, row$radical_cure_rates, 1, row$seasonality_ratio, row$transmission_rates)
} else {
stop("Method invalid")
}
}, timeout=timelimit_per_run, onTimeout="warning")
end = Sys.time()
# Store elapsed time in seconds so it is comparable with timelimit_per_run
result = list(estimate = est, time = as.numeric(difftime(end, start, units = "secs")))
write_rds(result, out_path, compress="gz")
return(result)
}
tictoc::tic()
estimates_all = pbmclapply(seq_len(nrow(data_scenarios_long)), run_scenario_method)
data_scenarios_long$estimate = lapply(estimates_all, function(x) {x$estimate})
data_scenarios_long$time = lapply(estimates_all, function(x) {x$time})
tictoc::toc()
## 0.117 sec elapsed
# data_scenarios = run_all(data_scenarios)
workspace_filename = format(Sys.time(), "%Y%m%d %H%M%S.Rdata")
# save.image(workspace_filename)
Compare densities for individual scenarios between methods
methods_count = sapply(names(comparison_colors), function(x) {
count = data_scenarios_long %>%
filter(method == x) %>%
nrow()
})
comparison_colors_count = comparison_colors
names(comparison_colors_count) = methods_count
plot_data = data_scenarios_long %>%
# pivot_longer(cols = matches("lambda"), names_to = "method", values_to = "estimate") %>%
mutate(rhat = calculate_rhat(estimate)) %>%
unnest(c(estimate, rhat)) %>%
# filter(rep == 1) %>%
# filter(rhat < 1.5) %>% # Get rid of instances that appeared to not converge
drop_na(estimate)
plot_data = plot_data %>%
group_by(scenario_ID, method, rep) %>%
mutate(trace = row_number()) %>%
ungroup()
for (id in unique(plot_data$scenario_ID)) {
.plot_cases = plot_data %>%
filter(scenario_ID == id) %>%
slice(1) %>%
unnest(c(ts, cases))
.plot_posterior = plot_data %>%
filter(scenario_ID == id) %>%
# filter(estimate > lq,
# estimate < uq) %>%
select(-scenario_ID, -ascertainment_rates, -radical_cure_rates, -population_size, -cases, -ts)
true_value = .plot_posterior$transmission_rates[1]
xlim_buffer = 0.1
limits = c(true_value * (1-xlim_buffer), true_value * (1+xlim_buffer))
limits_2 = c(true_value * (1-2*xlim_buffer), true_value * (1+2*xlim_buffer))
title = with(.plot_posterior[1,], paste0(id, ": lambda=", transmission_rates, " eps=", seasonality_ratio))
# Case plot
p1 = ggplot(.plot_cases, aes(x = ts/years, y = cases)) +
geom_point() +
labs(title = title, x = "Years")
# Posterior density
p2 = .plot_posterior %>%
filter(estimate >= limits[1], estimate <= limits[2]) %>%
ggplot(aes(x = estimate, fill = method, color = method, group=interaction(method, rep))) +
# ggplot(aes(x = estimate, y = ..scaled.., fill = method, color = method)) +
stat_density(position="identity", alpha = 0.25) +
geom_vline(xintercept = true_value, linetype="dashed") +
scale_colour_manual(values = comparison_colors) +
scale_fill_manual(values = comparison_colors) +
coord_cartesian(xlim = limits)
# Trace plot
p3 = .plot_posterior %>%
group_by(method) %>%
ggplot(aes(x = trace, y = estimate, color = method, group=interaction(method, rep))) +
geom_line(alpha = 0.25) +
geom_hline(yintercept = true_value, linetype="dashed") +
scale_colour_manual(values = comparison_colors) +
coord_cartesian(ylim = limits_2)
print(p1 / p2 / p3 + plot_layout(guides = "collect"))
}
Some of these chains do not appear to mix.
not_run = sapply(data_scenarios_long$estimate, is.null)
data_scenarios_long$time[not_run] = -10
data_scenarios_long %>%
mutate(time = unlist(time)) %>%
ggplot(aes(x = time, fill = method)) +
geom_histogram(alpha = 0.7, binwidth=1) +
geom_vline(xintercept = -0.5) +
scale_fill_manual(values = comparison_colors) +
labs(title = "Which methods took a long time?",
subtitle = paste(sum(not_run), "scenarios terminated after timeout of", timelimit_per_run, "seconds are placed at -10")) +
facet_grid(rows = vars(seasonality_ratio),
cols = vars(transmission_rates))
rhat_threshold = 1.05
convergence = plot_data %>%
group_by(scenario_ID, rep, method) %>%
slice(1) %>%
filter(rhat > rhat_threshold)
plot_data %>%
group_by(scenario_ID, rep, method) %>%
slice(1) %>%
ggplot(aes(x = rhat, fill = method)) +
geom_histogram(alpha = 0.7, binwidth=0.1) +
scale_fill_manual(values = comparison_colors) +
labs(title = "Did the runs converge?",
subtitle = paste(nrow(convergence), "runs had an rhat on lambda greater than", rhat_threshold)) +
facet_grid(rows = vars(seasonality_ratio),
cols = vars(transmission_rates))
## Warning: Removed 13 rows containing non-finite values (`stat_bin()`).
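The convergence check above relies on calculate_rhat() from methods.R, whose implementation is not shown. For reference, here is a minimal sketch of the classic split-\(\hat{R}\) statistic for one parameter; it is a hypothetical stand-in and may differ from what calculate_rhat() does. rstan also exports Rhat(), which implements a rank-normalised version.

```r
# Classic split-Rhat for one parameter (hypothetical stand-in for calculate_rhat).
# draws: matrix with rows = post-warmup iterations, columns = chains.
split_rhat <- function(draws) {
  n <- nrow(draws) %/% 2
  # Split each chain into its first and last n draws (middle draw dropped if odd)
  halves <- cbind(draws[seq_len(n), , drop = FALSE],
                  draws[nrow(draws) - n + seq_len(n), , drop = FALSE])
  chain_means <- colMeans(halves)
  chain_vars <- apply(halves, 2, var)
  B <- n * var(chain_means)          # between-chain variance
  W <- mean(chain_vars)              # within-chain variance
  var_plus <- (n - 1) / n * W + B / n
  sqrt(var_plus / W)
}

set.seed(1)
good <- matrix(rnorm(2000), ncol = 2)              # two chains, same target
bad <- cbind(rnorm(1000), rnorm(1000, mean = 3))   # second chain stuck elsewhere
```

Values close to 1 indicate that the split half-chains agree; a chain stuck in a different region inflates the between-chain variance and pushes \(\hat{R}\) well above rhat_threshold.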
How bad are the errors?
errors = data_scenarios_long %>%
unnest(estimate) %>%
mutate(error = as.numeric((estimate - transmission_rates)/transmission_rates)^2) %>%
group_by(transmission_rates, seasonality_ratio, method) %>%
summarise(rmse = sqrt(mean(error)), .groups="drop")
ggplot(errors, aes(x = seasonality_ratio, y = rmse, color = method)) +
geom_line() +
facet_grid(rows = vars(transmission_rates), scales="free") +
labs(subtitle = "Facet by transmission rate")
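The rmse column computed above is a relative RMSE pooled over all posterior draws \(k = 1, \dots, K\) from all repetitions within each (transmission rate, seasonality ratio, method) group:

\[
\mathrm{RMSE}_{\mathrm{rel}} = \sqrt{\frac{1}{K}\sum_{k=1}^{K}\left(\frac{\hat{\lambda}_k - \lambda}{\lambda}\right)^2}
\]

Because it pools draws rather than point estimates, it penalises both bias and posterior width.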
What do mean predictions look like?
While we previously fit only the transmission rate \(\lambda\) (and the overdispersion \(1/\phi\) for negative-binomial models), the seasonal models introduce more complexity. Here we parameterise seasonality by three parameters: the peak/trough ratio \(\epsilon\), the ‘sharpness’ \(\kappa\), and the offset from the start of the year to the peak, which we call the phase (rather than \(\phi\), because \(\phi\) is already taken). These parameters could either be set equal to transmission suitability as a function of temperature/precipitation (as in Mordecai et al., or the Malaria Atlas Project (MAP)), or estimated simultaneously with the transmission parameters.
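A minimal sketch of one common functional form with these three parameters is below. It is an assumption for illustration only; the Stan files may implement a different form. It interprets \(\epsilon\) as the trough-to-peak level, consistent with the code comment that eps = 1 means no seasonality, and takes the phase as a fraction of the period (also an assumption).

```r
# Assumed seasonal forcing (illustrative; the Stan models may differ):
#   s(t) = eps + (1 - eps) * ((1 + cos(2*pi*(t/period - phase))) / 2)^kappa
# eps    : trough level relative to the peak (1 = no seasonality, 0 = full)
# kappa  : sharpness (larger values give a narrower peak)
# phase  : timing of the peak as a fraction of the period (assumed unit)
# period : cycle length in the same units as t (assumed days)
seasonal_forcing <- function(t, eps, kappa, phase, period = 365.25) {
  wave <- (1 + cos(2 * pi * (t / period - phase))) / 2   # in [0, 1], peak = 1
  eps + (1 - eps) * wave^kappa
}

# The seasonal transmission rate would then be lambda * seasonal_forcing(t, ...)
```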
theta_init_ext = theta_init
theta_init_ext$eps = 0.5
theta_init_ext$kappa = 1
theta_init_ext$phase = 0.01
fit_seasonal_ext = sampling(stan_model_champagne2022_seasonal_ext,
data = data,
iter = 100,
chains = n_chains,
init = rep(list(theta_init_ext), n_chains), # Start from MLE solution
seed = 0)
Plot diagnostics
rstan::extract(fit_seasonal_ext, c("lambda", "phi_inv", "eps", "kappa", "phase")) %>%
lapply(as.numeric) %>%
as_tibble() %>%
pivot_longer(everything()) %>%
ggplot(aes(x = value)) +
geom_density(color="steelblue", fill="steelblue", alpha=0.75) +
geom_vline(data=real_params_df, aes(xintercept=value), linetype="dashed") +
scale_x_continuous(limits = c(0, NA)) +
facet_wrap(vars(name), scales = "free") +
labs(title = "Re-estimated posterior densities with extended seasonal model",
subtitle = "Dashed line: simulated value")
smr_pred_ext <- with(synth_df, cbind(as.data.frame(
summary(
fit_seasonal_ext,
pars = "sim_cases",
probs = c(0.05, 0.5, 0.95)
)$summary),
t=t[1:(n_times-1)],
cases=cases[1:(n_times-1)])) %>%
setNames(colnames({.}))
ggplot(smr_pred_ext, mapping = aes(x = t)) +
geom_ribbon(aes(ymin = `5%`, ymax = `95%`), fill = c_posterior, alpha = 0.35) +
geom_line(mapping = aes(y = `50%`), color = c_posterior) +
geom_point(mapping = aes(y = cases)) +
labs(x = "Day", y = "Cases") +
coord_cartesian(ylim = c(0, NA))
How is goodness of fit under repetition?
# Mean squared error
loss_fn = function(est_value, true_value) {
mean((est_value - true_value)^2)
}
In particu
We test whether parameter recovery works on real datasets. We will find data from a variety of settings (e.g., transmission levels, remoteness, strains) to demonstrate generalisability.
Data:
data_hainan = list(NULL)
data_brazil = list(low = NULL,
med = NULL,
high = NULL)
The parameter-fitting routine will be run for each real-world dataset.
fit_brazil = lapply(data_brazil, function(x) {
# perform Bayesian fit
})
# Show visual diagnostics to demonstrate that the model outputs actually reflect the observed data (this is not guaranteed because Champagne never used seasonal data)
# Show parameter posterior distributions
With the parameter posterior distributions, we can start to discuss outcomes. We expect that the main advancement here is demonstrating heterogeneity in epidemic parameters; while it is clear there is heterogeneity in prevalence, administrative data is not normally able to demonstrate this. As James highlighted, this also raises the question of whether multiple regions should comprise a hierarchical model. The Brazilian data is very comprehensive across regions and makes this logical, but I suggest exploring this after performing independent modelling first.
The Chinese and Brazilian datasets include longitudinal data long enough to observe trends that must be due to changes in policy or environment. Therefore, the seasonal Champagne model will certainly not be sufficient to explain this variation.
We will modify the Champagne model’s transmission rate \(\lambda\) to be time-varying, \(\lambda(t)\). The functional form of \(\lambda(t)\) is unclear but is constrained by the information available to fit it. For example, we will never know if a decrease in trend is due to natural decay to a low equilibrium, or due to a decreasing \(\lambda(t)\). However, by finding a reasonable form, \(\lambda(t)\) will account for long-term variation that cannot be explained by other parameters and we hope it will allow the other parameters to be recovered in scenarios where there is no form of static or dynamic steady-state at play.
At a minimum, \(\lambda(t)\) will be a piecewise-constant or piecewise-linear function with breakpoints set manually where there are obvious changes in transmission intensity. This may prove sufficient. Other suggestions include state-space filtering methods such as Kalman filtering or particle filtering.
lambda = function(t, ...) {
}
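One possible way to fill in the lambda stub above is sketched below using base R's approxfun(). The function names, breakpoints and values are hypothetical; breakpoints are assumed to be chosen manually in the same time units as ts, and \(\lambda\) is held constant outside the breakpoint range.

```r
# Hypothetical piecewise-linear transmission rate lambda(t).
# breaks : manually chosen breakpoint times (same units as ts)
# values : lambda at each breakpoint; held constant outside the range (rule = 2)
make_lambda_piecewise_linear <- function(breaks, values) {
  stopifnot(length(breaks) == length(values), !is.unsorted(breaks))
  approxfun(breaks, values, method = "linear", rule = 2)
}

# Piecewise-constant alternative: each value holds until the next breakpoint
make_lambda_piecewise_constant <- function(breaks, values) {
  stopifnot(length(breaks) == length(values), !is.unsorted(breaks))
  approxfun(breaks, values, method = "constant", f = 0, rule = 2)
}

lam_linear <- make_lambda_piecewise_linear(c(0, 100, 200), c(0.02, 0.02, 0.01))
lam_step <- make_lambda_piecewise_constant(c(0, 100), c(0.02, 0.01))
```

In the Stan models, the same interpolation would have to be coded inside the ODE right-hand side, with the breakpoints passed as data and the values at each breakpoint estimated as parameters.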